# Standard Library Imports
import datetime
import logging as log
import os
import random

# Third-Party Imports
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision import transforms
from scipy.ndimage import gaussian_filter
from skimage import measure
from sklearn.metrics import auc


def set_seed(seed):
    """Seed every RNG this module relies on with the same ``seed``.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    generator; when CUDA is available, the current device and all CUDA
    devices are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def setup_logger(name, device_id, base_dir="./log", level=log.INFO):
    """Create (or reset) a named logger that writes to a file and the console.

    The log file is ``<base_dir>/<MM-DD>/<HH-MM-SS>_<device_id>.txt``, using the
    local time at the moment of the call, and it is opened in ``"w"`` mode.
    Handlers already attached to the logger are closed and removed first. This
    means calling the function again with the same ``name`` neither duplicates
    output nor leaks open log files.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        device_id: Suffix for the log filename; it is inserted literally.
        base_dir: Root directory for the per-day log folders; created if missing.
        level: Level set on the logger.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = log.getLogger(name)
    logger.setLevel(level)
    # Close before removing: simply dropping a FileHandler (handlers.clear())
    # leaves its file descriptor open for the rest of the process.
    for handler in list(logger.handlers):
        handler.close()
        logger.removeHandler(handler)

    now = datetime.datetime.now()
    log_dir = os.path.join(base_dir, now.strftime("%m-%d"))
    os.makedirs(log_dir, exist_ok=True)

    # Only the time part goes through strftime. device_id is appended
    # afterwards, so a '%' inside it is kept verbatim instead of being parsed
    # as a strftime directive.
    log_filename = f"{now.strftime('%H-%M-%S')}_{device_id}.txt"
    txt_path = os.path.join(log_dir, log_filename)

    formatter = log.Formatter(
        "%(asctime)s - %(levelname)s: %(message)s",
        datefmt="%y-%m-%d %H:%M:%S",
    )

    file_handler = log.FileHandler(txt_path, mode="w")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    console_handler = log.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    return logger


def rescale(x):
    """Min-max normalize ``x`` onto [0, 1].

    Works on any array-like that supports ``min()``, ``max()`` and broadcasting
    arithmetic (NumPy arrays, torch tensors). A constant input has zero range.
    In that case the function returns all zeros instead of dividing 0 by 0,
    which would give NaN.
    """
    x_min = x.min()
    span = x.max() - x_min
    if span == 0:
        # Constant input: every element already sits at the minimum.
        span = 1
    return (x - x_min) / span


def _label_gt_components(gt_masks):
    """Assign every connected ground-truth region in the stack a unique id.

    Each ``gt_masks[i]`` is labelled independently by ``measure.label`` with
    ``connectivity=2`` (8-connectivity for 2-D masks). Ids in later images are
    offset so they never collide across images. Background pixels stay 0.

    Returns:
        ``(component_ids, num_regions)``: an int64 array shaped like
        ``gt_masks``, and the total number of regions found.
    """
    component_ids = np.zeros(gt_masks.shape, dtype=np.int64)
    num_regions = 0
    for i, gt in enumerate(gt_masks):
        labels, count = measure.label(gt, connectivity=2, return_num=True)
        component_ids[i] = np.where(labels > 0, labels + num_regions, 0)
        num_regions += count
    return component_ids, num_regions


def cal_pro_score(labeled_imgs, score_imgs, fpr_thresh=0.3, max_steps=200):
    """Area under the per-region-overlap (PRO) curve up to ``fpr_thresh`` FPR.

    The function sweeps ``max_steps`` thresholds downward from
    ``score_imgs.max()`` in equal steps of ``(max - min) / max_steps``. At each
    threshold, pixels scoring strictly above it are predicted anomalous, and
    two values are recorded:

    * PRO: for every connected ground-truth region (over all images), the
      fraction of its pixels predicted anomalous, averaged over all regions.
    * FPR: the fraction of ground-truth-normal pixels predicted anomalous.

    Operating points with ``FPR <= fpr_thresh`` are kept and their FPRs are
    min-max rescaled onto [0, 1]. The function then returns the trapezoidal
    area under PRO vs. FPR, computed with ``sklearn.metrics.auc``.

    Args:
        labeled_imgs: Ground-truth masks, array-like with the same shape as
            ``score_imgs``. Regions are labelled per leading index (presumably
            ``(N, H, W)`` — confirm with callers). Values ``> 0.45`` count as
            anomalous.
        score_imgs: Anomaly scores; higher means more anomalous.
        fpr_thresh: Upper FPR limit of the integrated part of the curve.
        max_steps: Number of thresholds evaluated.

    Returns:
        The normalized PRO-AUC, or NaN when the masks contain no anomalous
        region (PRO is undefined in that case).

    Raises:
        ValueError: If the two inputs have different shapes.

    Note:
        Each region's overlap is divided by the pixel count of that exact
        connected component, so per-region PRO always lies in [0, 1]. The
        previous version counted hits over the hole-filled region but divided
        by the unfilled area, which could exceed 1 for regions with holes.
        The first threshold equals the maximum score and therefore yields
        FPR 0. The FPR axis is then normalized by the largest FPR actually
        reached at or below ``fpr_thresh``, not by ``fpr_thresh`` itself.
    """
    gt_masks = np.asarray(labeled_imgs) > 0.45
    score_imgs = np.asarray(score_imgs)
    if gt_masks.shape != score_imgs.shape:
        raise ValueError(
            f"labeled_imgs shape {gt_masks.shape} does not match "
            f"score_imgs shape {score_imgs.shape}"
        )

    # Region labelling, region sizes and the negative mask do not depend on
    # the threshold, so compute them once rather than at every step.
    component_ids, num_regions = _label_gt_components(gt_masks)
    if num_regions == 0:
        return float("nan")
    # region_areas[k - 1] is the pixel count of region id k (always >= 1).
    region_areas = np.bincount(component_ids.ravel(), minlength=num_regions + 1)[1:]
    negatives = ~gt_masks
    num_negatives = negatives.sum()

    max_th = score_imgs.max()
    min_th = score_imgs.min()
    delta = (max_th - min_th) / max_steps

    pros_mean = []
    fprs = []
    for step in range(max_steps):
        thred = max_th - step * delta
        pred = score_imgs > thred
        # Count the predicted-anomalous pixels inside each region (drop id 0 = background).
        hits = np.bincount(component_ids[pred], minlength=num_regions + 1)[1:]
        pros_mean.append((hits / region_areas).mean())
        fprs.append(np.logical_and(negatives, pred).sum() / num_negatives)

    fprs = np.array(fprs)
    pros_mean = np.array(pros_mean)

    # Integrate only the low-FPR part of the curve, with its FPR range
    # stretched onto [0, 1].
    selected = fprs <= fpr_thresh
    return auc(rescale(fprs[selected]), pros_mean[selected])
    

def get_new_size(size, target_short_edge=240):
    """Scale ``(width, height)`` so the shorter edge equals ``target_short_edge``.

    The aspect ratio is preserved and the longer edge is floored to an integer.
    Integer arithmetic keeps the result exact. The earlier float form,
    ``int(edge * (target / short))``, could round a hair below an integer and
    then truncate one pixel short.

    Args:
        size: ``(width, height)`` pair.
        target_short_edge: Desired length of the shorter edge, in pixels.

    Returns:
        ``(new_width, new_height)`` as ints.
    """
    width, height = size
    short_edge = min(height, width)
    # For ints, (edge * target) // short == floor(edge * target / short) exactly.
    new_width = int(width * target_short_edge // short_edge)
    new_height = int(height * target_short_edge // short_edge)
    return new_width, new_height


def resize_tensor(tensor, new_width, new_height):
    """Bilinearly resize a single map to ``(new_height, new_width)``.

    ``F.interpolate`` in bilinear mode wants a batched, channelled 4-D input.
    The tensor therefore gets two leading singleton dimensions for the call,
    and they are stripped from the result.
    """
    batched = tensor[None, None]
    resized = F.interpolate(
        batched,
        size=(new_height, new_width),
        mode='bilinear',
        align_corners=False,
    )
    return resized[0, 0]


def transform_image(image, new_width, new_height):
    """Preprocess an image into a normalized CHW float tensor.

    The steps are:

    1. bicubic resize to ``(new_height, new_width)``;
    2. 240x240 center crop;
    3. conversion to RGB via ``_convert_to_rgb`` (presumably ``image`` is a PIL
       image, since that helper calls ``.convert``);
    4. ``ToTensor``;
    5. per-channel normalization.

    The mean/std values appear to be the OpenAI CLIP preprocessing constants
    (confirm).
    """
    pipeline = transforms.Compose([
        transforms.Resize(
            size=(new_height, new_width),
            interpolation=transforms.InterpolationMode.BICUBIC,
            max_size=None,
            antialias=None,
        ),
        transforms.CenterCrop(size=(240, 240)),
        _convert_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ])
    return pipeline(image)


def _convert_to_rgb(image):
    return image.convert('RGB')


def apply_gaussian_blur(score_map, sigma=4.0):
    """Smooth a score map with a Gaussian filter.

    The tensor is detached and copied to the CPU before filtering, because
    ``Tensor.numpy()`` refuses tensors that require grad. The result is always
    a new float32 tensor on the CPU, whatever the input device or dtype.

    ``sigma`` is a scalar, so ``scipy.ndimage.gaussian_filter`` smooths along
    every axis. A batched input would therefore also be blurred across the
    batch axis. Presumably callers pass a single 2-D map (TODO confirm).

    Args:
        score_map: Torch tensor of anomaly scores.
        sigma: Standard deviation of the Gaussian kernel, in pixels.

    Returns:
        The blurred map as a float32 CPU tensor.
    """
    blurred = gaussian_filter(score_map.detach().cpu().numpy(), sigma=sigma)
    return torch.tensor(blurred, dtype=torch.float32)


def adjust_scale(source_tensor, target_tensor):
    """Linearly map ``source_tensor`` onto the value range of ``target_tensor``.

    ``source_tensor`` is min-max normalized to [0, 1] and then stretched to
    ``[target_tensor.min(), target_tensor.max()]``. A constant
    ``source_tensor`` has zero range. In that case every element maps to
    ``target_tensor.min()`` instead of dividing 0 by 0, which would give NaN.

    Works with torch tensors or NumPy arrays (anything with ``min``/``max``
    and broadcasting arithmetic).
    """
    source_min = source_tensor.min()
    source_span = source_tensor.max() - source_min
    if source_span == 0:
        # Constant source: normalized values are all 0.
        source_span = 1

    target_min = target_tensor.min()
    target_span = target_tensor.max() - target_min

    normalized = (source_tensor - source_min) / source_span
    return normalized * target_span + target_min
